use crate::constants::{D_MESG, D_PBLC};
use crate::ots::error::LmsOtsInvalidPrivateKey;
use crate::ots::modes::LmsOtsMode;
use crate::ots::public::VerifyingKey;
use crate::ots::signature::Signature;

use crate::types::Identifier;
use digest::{Digest, Output};
use hybrid_array::Array;
use rand_core::{CryptoRng, TryCryptoRng};
use signature::{Error, RandomizedSignerMut};
use zeroize::Zeroize;
//use std::mem::MaybeUninit;

/// Opaque struct representing an LM-OTS private key.
///
/// Does not implement [Clone] because OTS keys are supposed to be one time use.
///
/// The [`Debug`](core::fmt::Debug) output intentionally omits the secret chain
/// values, and those values are zeroized when the key is dropped.
pub struct SigningKey<Mode: LmsOtsMode> {
    /// Leaf index `q` (must be all-zero when LM-OTS is used directly).
    q: u32,
    /// Key-pair identifier `I`.
    id: Identifier,
    /// Secret chain starting values `x[0..p]`.
    x: Array<Output<Mode::Hasher>, Mode::PLen>,
    /// `false` once the key has been used to produce a signature.
    valid: bool,
}

// Hand-written instead of derived so secret key material never ends up in
// logs or panic messages. Unlike `#[derive(Debug)]`, this also does not
// require `Mode: Debug`.
impl<Mode: LmsOtsMode> core::fmt::Debug for SigningKey<Mode> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("SigningKey")
            .field("q", &self.q)
            .field("id", &self.id)
            .field("valid", &self.valid)
            .finish_non_exhaustive()
    }
}

// Wipe the secret chain values when the key goes out of scope, whether or not
// it was ever used to sign.
impl<Mode: LmsOtsMode> Drop for SigningKey<Mode> {
    fn drop(&mut self) {
        self.x.zeroize();
    }
}

impl<Mode: LmsOtsMode> SigningKey<Mode> {
    /// Generates a private key, expanded pseudorandomly from a seed drawn from `rng`.
    ///
    /// A seed of `Mode::NLen` bytes is filled from `rng` and expanded with the
    /// algorithm from Appendix A of <https://datatracker.ietf.org/doc/html/rfc8554#appendix-A>
    /// (see [`Self::new_from_seed`]). The local seed buffer is zeroized before returning.
    ///
    /// If LM-OTS is being used directly, `q` MUST be set to the all-zero value
    /// <https://datatracker.ietf.org/doc/html/rfc8554#section-4>
    pub fn new<R: CryptoRng>(q: u32, id: Identifier, rng: &mut R) -> Self {
        let mut seed: Array<u8, Mode::NLen> = Array::default();
        rng.fill_bytes(&mut seed);
        let key = Self::new_from_seed(q, id, &seed[..]);
        // Every secret chain value is derived from the seed, so wipe our copy.
        seed.zeroize();
        key
    }

    /// Returns a private key generated pseudorandomly from `seed`.
    ///
    /// Follows Appendix A of <https://datatracker.ietf.org/doc/html/rfc8554#appendix-A>:
    /// for each `i` in `0..p`,
    /// `x[i] = H(I || u32str(q) || u16str(i) || u8str(0xff) || SEED)`.
    ///
    /// If LM-OTS is being used directly, `q` MUST be set to the all-zero value
    /// <https://datatracker.ietf.org/doc/html/rfc8554#section-4>
    pub fn new_from_seed(q: u32, id: Identifier, seed: impl AsRef<[u8]>) -> Self {
        let seed = seed.as_ref();
        let x = Array::from_fn(|i| {
            Mode::Hasher::new()
                .chain_update(id)
                .chain_update(q.to_be_bytes())
                .chain_update((i as u16).to_be_bytes())
                .chain_update([0xff])
                .chain_update(seed)
                .finalize()
        });

        Self {
            q,
            id,
            x,
            valid: true,
        }
    }

    /// Computes the LM-OTS public key for this private key.
    ///
    /// Implements Algorithm 1 from <https://datatracker.ietf.org/doc/html/rfc8554#section-4.3>.
    /// Each secret value `x[i]` is hashed `2^w - 1` times along its chain, using
    /// `tmp = H(I || u32str(q) || u16str(i) || u8str(j) || tmp)` for step `j`.
    /// The chain ends `y[i]` are then absorbed in order into
    /// `K = H(I || u32str(q) || D_PBLC || y[0] || ... || y[p-1])`.
    ///
    /// Note: signing zeroizes `q`, `I` and `x`, so calling this on a key that
    /// has already signed does not reproduce the original public key.
    pub fn public(&self) -> VerifyingKey<Mode> {
        let mut hasher = Mode::Hasher::new()
            .chain_update(self.id)
            .chain_update(self.q.to_be_bytes())
            .chain_update(D_PBLC);

        // Walk each chain to its end: steps j = 0 ..= 2^w - 2.
        let chain_steps = (1u32 << Mode::W) - 1;
        // Scratch buffer reused for every chain.
        let mut tmp = Output::<Mode::Hasher>::default();
        for i in 0..Mode::P {
            tmp.clone_from(&self.x[i]);
            for j in 0..chain_steps {
                Mode::Hasher::new()
                    .chain_update(self.id)
                    .chain_update(self.q.to_be_bytes())
                    .chain_update((i as u16).to_be_bytes())
                    .chain_update((j as u8).to_be_bytes())
                    .chain_update(&tmp)
                    .finalize_into(&mut tmp);
            }
            hasher.update(&tmp);
        }

        VerifyingKey {
            id: self.id,
            q: self.q,
            k: hasher.finalize(),
        }
    }

    /// `true` if the private key can be used for signing operations.
    ///
    /// `false` if it has already been used.
    pub fn is_valid(&self) -> bool {
        self.valid
    }
}

impl<Mode: LmsOtsMode> RandomizedSignerMut<Signature<Mode>> for SigningKey<Mode> {
    fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &mut self,
        rng: &mut R,
        msg: &[u8],
    ) -> Result<Signature<Mode>, Error> {
        if !self.valid {
            return Err(Error::from_source(LmsOtsInvalidPrivateKey {}));
        }

        // Generate the message randomizer C
        let mut c = <Output<Mode::Hasher>>::default();
        rng.try_fill_bytes(&mut c).map_err(|_| Error::new())?;

        // Q is the randomized message hash
        let q = Mode::Hasher::new()
            .chain_update(self.id)
            .chain_update(self.q.to_be_bytes())
            .chain_update(D_MESG)
            .chain_update(&c)
            .chain_update(msg)
            .finalize();

        // Y is the signature. We iterate over the message hash and checksum expanded into Winternitz coefficients
        let y = Mode::expand(&q).into_iter().enumerate().map(|(i, a)| {
            let a = a as u32;
            let mut tmp = self.x[i].clone();
            for j in 0..a {
                Mode::Hasher::new()
                    .chain_update(self.id)
                    .chain_update(self.q.to_be_bytes())
                    .chain_update((i as u16).to_be_bytes())
                    .chain_update((j as u8).to_be_bytes())
                    .chain_update(&tmp)
                    .finalize_into(&mut tmp);
            }
            tmp
        });
        let y = Array::from_iter(y);

        let sig = Signature { c, y };

        // zero out fields so we can't use the private key a second time
        self.q.zeroize();
        self.id.zeroize();
        self.x.zeroize();
        self.valid = false;

        Ok(sig)
    }
}
